import copy
from knowledge_tracing.args import ARGS
from datasets.dataset_parser import Constants
import torch
import torch.nn as nn
import numpy as np


def clones(module, N):
    """Return an nn.ModuleList of N independent deep copies of `module`.

    Each copy has its own parameters (copy.deepcopy), so the layers do not
    share weights with each other or with `module`.
    """
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)


def shift_right(tensor, n, pad_value=0, dim=-1):
    '''
    Shift `tensor` right by `n` positions along `dim`, padding on the left.

    Args:
        tensor: Tensor to shift
        n: Distance to shift by; must satisfy 0 <= n <= tensor.size(dim)
        pad_value: Value written into the n vacated leading slots
        dim: Dimension along which to shift (negative values allowed)
    Returns:
        A tensor with the same shape, dtype and device as `tensor`. Its first n
        entries along `dim` are `pad_value`, and the last n entries of the input
        are dropped.
    '''
    axis_len = tensor.size(dim)
    assert 0 <= n <= axis_len
    # A length-n slice along `dim` gives the exact pad shape/dtype/device.
    filler = torch.full_like(tensor.narrow(dim, 0, n), pad_value)
    kept = tensor.narrow(dim, 0, axis_len - n)
    return torch.cat((filler, kept), dim)


class ScheduledOptim():
    '''Optimizer wrapper that sets the learning rate before every step.

    The rate is d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5):
    linear warmup for n_warmup_steps, then inverse-square-root decay.
    '''
    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        # Base scale shared by every step of the schedule.
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Refresh the learning rate, then step the wrapped optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Clear gradients through the wrapped optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        decay_term = np.power(self.n_current_steps, -0.5)
        warmup_term = np.power(self.n_warmup_steps, -1.5) * self.n_current_steps
        return min(decay_term, warmup_term)

    def _update_learning_rate(self):
        ''' Advance the step counter and write the new lr into every param group '''
        self.n_current_steps += 1
        new_lr = self.init_lr * self._get_lr_scale()
        for group in self._optimizer.param_groups:
            group['lr'] = new_lr


class NoamOpt:
    "Optim wrapper that implements the Noam learning-rate schedule."

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0  # last rate written into the param groups

    def zero_grad(self):
        "Zero gradients of the wrapped optimizer."
        self.optimizer.zero_grad()

    def step(self):
        "Advance the step counter, set the new rate on every param group, then step."
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """
        Learning rate at `step` (defaults to the current step count):
        factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5).

        Step 0 returns 0.0, the limit of the formula, instead of raising
        ZeroDivisionError from `0 ** -0.5`. This makes rate() safe to query
        before the first step().
        """
        if step is None:
            step = self._step
        if step == 0:
            return 0.0
        return self.factor * \
               (self.model_size ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))


def get_constraint_losses(data, last_output):
    """
    Collect the enabled augmentation-consistency losses.

    For each name in ARGS.augmentations that is 'rep', 'ins' or 'del' and whose
    ARGS.<name>_cons_loss flag is truthy, the matching loss function is called.
    The result is stored under that name unless the function returned None.
    Other augmentation names are ignored.

    Returns:
        dict mapping augmentation name to its loss value.
    """
    # augmentation name -> (ARGS flag enabling it, loss function).
    # Flags are read lazily via getattr, only when that augmentation is listed.
    dispatch = {
        'rep': ('rep_cons_loss', get_replacement_loss),
        'ins': ('ins_cons_loss', get_insertion_loss),
        'del': ('del_cons_loss', get_deletion_loss),
    }
    collected = {}
    for name in ARGS.augmentations:
        entry = dispatch.get(name)
        if entry is None:
            continue
        flag_name, loss_fn = entry
        if not getattr(ARGS, flag_name):
            continue
        value = loss_fn(data, last_output)
        if value is not None:
            collected[name] = value
    return collected


def get_replacement_loss(data, last_output):
    """
    MSE consistency loss between predictions on the original sequence and on
    its replacement-augmented version.

    The mode is selected by ARGS:
      * ``not ARGS.rep_pred``: compare positions listed in
        ``data['rep']['idx'][b]`` (the non-negative entries, taken as a prefix).
      * ``ARGS.rep_only``: compare the complement of those positions within
        ``range(data['ori']['sequence_size'][b])``. The replacement-side
        output is detached unless ``ARGS.rep_backprob``.
      * otherwise: compare every position where ``data['rep']['loss_mask']``
        is True.

    Returns:
        Scalar tensor: summed squared error divided by the number of compared
        positions. Returns None when no position is compared; the caller
        (get_constraint_losses) skips None.
    """
    loss = torch.nn.MSELoss(reduction='none')
    rep_loss = None
    loss_num = 0
    if not ARGS.rep_pred:
        bsz = data['rep']['idx'].shape[0]
        for b in range(bsz):
            # NOTE(review): assumes valid (>= 0) indices form a prefix padded
            # with negative values — confirm against the data pipeline.
            not_rep_num = (data['rep']['idx'][b] >= 0).sum().item()
            loss_num += not_rep_num
            not_rep_idx = data['rep']['idx'][b][:not_rep_num]
            filtered_ori_output = last_output['ori'][b][not_rep_idx]
            filtered_rep_output = last_output['rep'][b][not_rep_idx]
            batch_rep_loss = loss(filtered_ori_output, filtered_rep_output).sum()
            rep_loss = batch_rep_loss if rep_loss is None else rep_loss + batch_rep_loss
    elif ARGS.rep_only:
        bsz = data['rep']['idx'].shape[0]
        for b in range(bsz):
            not_rep_num = (data['rep']['idx'][b] >= 0).sum().item()
            not_rep_idx = data['rep']['idx'][b][:not_rep_num]
            seq_size = data['ori']['sequence_size'][b].item()
            # Positions inside the sequence that are NOT in the idx list.
            rep_idx = sorted(set(range(seq_size)) - set(not_rep_idx.tolist()))
            loss_num += (seq_size - not_rep_num)
            filtered_ori_output = last_output['ori'][b][rep_idx]
            filtered_rep_output = last_output['rep'][b][rep_idx]
            if not ARGS.rep_backprob:
                filtered_rep_output = filtered_rep_output.detach()
            batch_rep_loss = loss(filtered_ori_output, filtered_rep_output).sum()
            rep_loss = batch_rep_loss if rep_loss is None else rep_loss + batch_rep_loss
    else:
        loss_mask = data['rep']['loss_mask']
        elem_loss = loss(last_output['rep'], last_output['ori'])  # presumably (bsz, seq_len, 1) — TODO confirm
        elem_loss = elem_loss.squeeze(-1).masked_fill(~loss_mask, 0)
        loss_num = loss_mask.sum().item()
        rep_loss = elem_loss.sum()

    # Nothing compared: return None instead of 0/0 = NaN (or, in mask mode,
    # an unreduced per-position tensor that cannot be backpropagated as a loss).
    if rep_loss is None or loss_num == 0:
        return None
    return rep_loss / loss_num


def get_insertion_loss(data, last_output):
    """
    One-sided (hinge) consistency loss between predictions on the original
    sequence and on its insertion-augmented version.

    For every sample whose inserted sequence is longer than the original, the
    first ori_len original outputs are compared with the inserted-sequence
    outputs at ``data['ins']['idx'][b][:ori_len]``. Presumably these are the
    positions the original items moved to after insertion; TODO confirm.

    With ARGS.ins_loss_dir == 'up', the loss penalizes inserted outputs that
    fall below the original outputs. Any other value penalizes inserted
    outputs above them.

    Returns:
        Scalar tensor: summed hinge penalty divided by the number of compared
        positions. Returns None when nothing is compared; the caller
        (get_constraint_losses) skips None.
    """
    bsz = data['ins']['idx'].shape[0]
    loss_num = 0
    ins_loss = None
    for b in range(bsz):
        ori_len = data['ori']['sequence_size'][b][0].item()
        ins_len = data['ins']['sequence_size'][b][0].item()
        ins_num = ins_len - ori_len
        if ins_num > 0:
            loss_num += ori_len
            rem_idx = data['ins']['idx'][b][:ori_len]
            filtered_ori_output = last_output['ori'][b][:ori_len]
            filtered_ins_output = last_output['ins'][b][rem_idx]
            if ARGS.ins_loss_dir == 'up':
                batch_ins_loss = (filtered_ori_output - filtered_ins_output).relu()
            else:  # down
                batch_ins_loss = (filtered_ins_output - filtered_ori_output).relu()
            batch_ins_loss = batch_ins_loss.sum()
            ins_loss = batch_ins_loss if ins_loss is None else ins_loss + batch_ins_loss

    # All contributing samples may have ori_len == 0: avoid 0/0 = NaN.
    if ins_loss is None or loss_num == 0:
        return None
    return ins_loss / loss_num


def get_deletion_loss(data, last_output):
    """
    One-sided (hinge) consistency loss between predictions on the original
    sequence and on its deletion-augmented version.

    For sample b, kept = data['del']['sequence_size'][b][0] outputs of the
    deleted sequence are compared with the original outputs at positions
    data['del']['idx'][b][:kept]. Samples with kept <= 0 are skipped.

    With ARGS.del_loss_dir == 'down', the loss penalizes deleted-sequence
    outputs above the original outputs. Any other value penalizes them being
    below.

    Returns:
        Summed hinge penalty divided by the total number of compared
        positions, or None if no sample contributed.
    """
    del_data = data['del']
    total_loss = None
    total_kept = 0
    for b in range(del_data['idx'].shape[0]):
        kept = del_data['sequence_size'][b][0].item()
        if kept > 0:
            total_kept += kept
            ori_pred = last_output['ori'][b][del_data['idx'][b][:kept]]
            del_pred = last_output['del'][b][:kept]
            gap = del_pred - ori_pred if ARGS.del_loss_dir == 'down' else ori_pred - del_pred
            sample_loss = gap.relu().sum()
            total_loss = sample_loss if total_loss is None else total_loss + sample_loss
    if total_loss is not None:
        total_loss = total_loss / total_kept
    return total_loss


def get_question_similarity_matrix():
    """
    Build a symmetric 0/1 question-adjacency matrix from shared skills.

    For every skill s in 1..NUM_TAGS present in Constants.SID_TO_QIDS, each
    pair of distinct positions (i < j) in that skill's question list is
    linked. Both [q_i - 1, q_j - 1] and [q_j - 1, q_i - 1] are set to 1, so
    question ids are treated as 1-based. The diagonal is only set if a skill
    lists the same question twice.

    Returns:
        torch.long tensor of shape (NUM_ITEMS, NUM_ITEMS).
    """
    data_constant = Constants(ARGS.dataset_name, ARGS.data_root)
    question_num = data_constant.NUM_ITEMS
    skill_num = data_constant.NUM_TAGS

    question_similarity_matrix = torch.zeros((question_num, question_num), dtype=torch.long)
    for s in range(1, skill_num + 1):
        if s not in data_constant.SID_TO_QIDS:
            continue
        questions = data_constant.SID_TO_QIDS[s]
        if len(questions) < 2:
            continue  # no pairs to link
        # Same (i < j) position pairs as a nested i/j loop, written in one
        # vectorized indexing op instead of O(k^2) scalar tensor assignments.
        pairs = torch.combinations(torch.as_tensor(questions, dtype=torch.long), r=2) - 1
        question_similarity_matrix[pairs[:, 0], pairs[:, 1]] = 1
        question_similarity_matrix[pairs[:, 1], pairs[:, 0]] = 1
    print("question_similarity_matrix generated")
    return question_similarity_matrix


def get_laplacian_loss(output, padding_mask, question_similarity_matrix):
    """
    Graph-Laplacian smoothness regularizer over per-question predictions.

    With A = question_similarity_matrix and L = D - A (D = diagonal matrix of
    A's row sums), each position contributes x^T L x, where
    x = output[b, t, 1:]. For a symmetric A this equals the sum over pairs
    i < j of A_ij * (x_i - x_j)^2, so linked questions are pushed towards
    similar predictions.

    Args:
        output: correctness probability predictions,
                float tensor of shape (batch_size, seq_size, question_num + 1);
                index 0 of the last dim is padding and is ignored
        padding_mask: bool tensor broadcastable to (batch_size, seq_size);
                False marks padding positions
        question_similarity_matrix: (question_num, question_num) adjacency
                matrix; converted to output's dtype and device
    Return:
        Laplacian regularization loss, single float number (tensor): the
        masked sum divided by (#non-padding positions * question_num), or 0
        when that count is 0
    """
    # 0-index is for padding, we need to ignore it for loss computation
    question_num = output.shape[2] - 1

    degree_matrix = torch.diag(question_similarity_matrix.sum(-1))  # (question_num, question_num)
    # Match output's dtype and device so the matmul also works for a CPU-built
    # similarity matrix or non-float32 predictions.
    laplacian_matrix = (degree_matrix - question_similarity_matrix).to(output)

    x = output[:, :, 1:].unsqueeze(-2)  # (batch_size, seq_size, 1, question_num)
    laplacian_loss = x @ laplacian_matrix @ x.transpose(-1, -2)  # (batch_size, seq_size, 1, 1)
    laplacian_loss = laplacian_loss.squeeze(-1).squeeze(-1)  # (batch_size, seq_size)
    laplacian_loss = laplacian_loss * padding_mask.to(laplacian_loss.dtype)

    # Clamp so an all-padding batch yields 0 instead of 0/0 = NaN.
    denominator = max(padding_mask.sum().item() * question_num, 1)
    return laplacian_loss.sum() / denominator
